# Copyright Huawei Technologies Co., Ltd. 2025. All rights reserved.

import os
import argparse
import json
import shutil

from typing import Any, Dict, Union
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from ascend_utils.common.security.path import get_valid_read_path, get_valid_write_path


MAX_KEY_LENGTH = 256
MAX_JSON_LENGTH = 4096


class SafeGenerator:
    def __init__(self):
        pass

    @staticmethod
    def get_config_from_pretrained(model_path, **kwargs):
        model_path = get_valid_read_path(model_path, is_dir=True, check_user_stat=False)
        try:
            config = AutoConfig.from_pretrained(model_path, local_files_only=True, **kwargs)
        except EnvironmentError as env_err:
            raise EnvironmentError(
                f"Get model from pretrained failed, please check model weights files in the model path. "
                f"If the file exists, make sure the folder's owner has execute permission."
                f"Original error: {env_err}"
            ) from env_err
        except Exception as err:
            raise ValueError(
                f"Get model from pretrained failed, please check model weights files in the model path. "
                f"If the file exists, make sure the folder's owner has execute permission."
                f"Original error: {err}"
            ) from err
        return config

    @staticmethod
    def get_model_from_pretrained(model_path, **kwargs):
        model_path = get_valid_read_path(model_path, is_dir=True, check_user_stat=False)
        try:
            model = AutoModelForCausalLM.from_pretrained(model_path, local_files_only=True, **kwargs)
        except EnvironmentError as env_err:
            raise EnvironmentError(
                f"Get model from pretrained failed, please check model weights files in the model path. "
                f"If the file exists, make sure the folder's owner has execute permission."
                f"Original error: {env_err}"
            ) from env_err
        except Exception as err:
            raise ValueError(
                f"Get model from pretrained failed, please check model weights files in the model path. "
                f"If the file exists, make sure the folder's owner has execute permission."
                f"Original error: {err}"
            ) from err
        return model

    @staticmethod
    def get_tokenizer_from_pretrained(model_path, **kwargs):
        model_path = get_valid_read_path(model_path, is_dir=True, check_user_stat=False)
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True, **kwargs)
        except EnvironmentError as env_err:
            raise EnvironmentError(
                f"Get model from pretrained failed, please check model weights files in the model path. "
                f"If the file exists, make sure the folder's owner has execute permission."
                f"Original error: {env_err}"
            ) from env_err
        except Exception as err:
            raise ValueError(
                f"Get model from pretrained failed, please check model weights files in the model path. "
                f"If the file exists, make sure the folder's owner has execute permission."
                f"Original error: {err}"
            ) from err
        return tokenizer

    @staticmethod
    def copy_tokenizer_files(model_dir, dest_dir):
        model_dir = get_valid_read_path(model_dir, is_dir=True, check_user_stat=False)
        if os.path.exists(dest_dir):
            dest_dir = get_valid_write_path(dest_dir, is_dir=True)
        else:
            os.makedirs(dest_dir, exist_ok=True)
            dest_dir = get_valid_write_path(dest_dir, is_dir=True)
        filenames = os.listdir(model_dir)
        max_file_num = 1024
        if len(filenames) > max_file_num:
            raise argparse.ArgumentTypeError(f"The file num in dir is {len(filenames)}, "
                                            f"which exceeds the limit {max_file_num}.")
        for filename in filenames:
            need_move = False
            file_names = ['tokenizer', 'tokenization', 'special_token_map', 'generation', 'configuration']
            for f in file_names:
                if f in filename:
                    need_move = True
                    break
            if need_move:
                src_filepath = os.path.join(model_dir, filename)
                dest_filepath = os.path.join(dest_dir, filename)
                shutil.copyfile(src_filepath, dest_filepath)

    @staticmethod
    def modify_config(model_dir, dest_dir, torch_dtype, quantize_type, args=None):
        model_dir = get_valid_read_path(model_dir, is_dir=True, check_user_stat=False)
        src_config_filepath = os.path.join(model_dir, 'config.json')
        src_config_filepath = get_valid_read_path(src_config_filepath, check_user_stat=False)
        with os.fdopen(os.open(src_config_filepath, os.O_RDONLY, 0o600),
                'r', encoding='utf-8') as fr:
            data = json.load(fr)
        data['torch_dtype'] = str(torch_dtype).split(".")[1]
        data['quantize'] = quantize_type
        if args is not None:
            quantization_config = {
                # 当is_lowbit为True，open_outlier为False时，group_size生效
                'group_size': args.group_size if args.is_lowbit and not args.open_outlier else 0,
                'kv_quant_type': "C8" if args.use_kvcache_quant else None,
                "fa_quant_type": "FAQuant" if args.use_fa_quant else None,
                'w_bit': args.w_bit,
                'a_bit': args.a_bit,
                'dev_type': args.device_type,
                'fraction': args.fraction,
                'act_method': args.act_method,
                'co_sparse': args.co_sparse,
                'anti_method': args.anti_method,
                'disable_level': args.disable_level,
                'do_smooth': args.do_smooth,
                'use_sigma': args.use_sigma,
                'sigma_factor': args.sigma_factor,
                'is_lowbit': args.is_lowbit,
                'mm_tensor': False,
                'w_sym': args.w_sym,
                'open_outlier': args.open_outlier,
                'is_dynamic': args.is_dynamic
            }
            if args.use_reduce_quant:
                quantization_config.update({"reduce_quant_type": "per_channel"})
            data['quantization_config'] = quantization_config

        dest_dir = get_valid_write_path(dest_dir, is_dir=True)
        dest_config_filepath = os.path.join(dest_dir, 'config.json')
        with os.fdopen(os.open(dest_config_filepath, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600),
                'w', encoding='utf-8') as fw:
            json.dump(data, fw, indent=4)

    @staticmethod
    def load_jsonl(dataset_path, key_name='inputs_pretokenized'):
        dataset = []
        if dataset_path == "humaneval_x.jsonl":
            key_name = 'prompt'
        with os.fdopen(os.open(dataset_path, os.O_RDONLY, 0o600),
                    'r', encoding='utf-8') as file:
            lines = file.readlines()
            for line in lines:
                data = json.loads(line)
                text = data.get(key_name, line)
                dataset.append(text)
        return dataset
    

class ArgumentValidator:
    context = None

    def __init__(self, *args, allow_none: bool = False, **kwargs):
        self.allow_none = allow_none
        self.validation_pipeline = []
        self.create_validation_pipeline()

    def validate(self, value: Any) -> None:
        if value is None and self.allow_none:
            return
        for method in self.validation_pipeline:
            method(value)

    def add_validation_method(self, method, position: int = None, target_method=None):
        if position is not None:
            self.validation_pipeline.insert(position, method)
        elif target_method and target_method in self.validation_pipeline:
            target_index = self.validation_pipeline.index(target_method)
            self.validation_pipeline.insert(target_index + 1, method)
        else:
            self.validation_pipeline.append(method)

    def delete_validation_method(self, method=None, position: int = None):
        if position is not None:
            if 0 <= position < len(self.validation_pipeline):
                self.validation_pipeline.pop(position)
        elif method and method in self.validation_pipeline:
            self.validation_pipeline.remove(method)

    def create_validation_pipeline(self):  # 空方法，子类可以覆盖
        pass

    def _create_validation_pipeline(self, *methods):
        self.validation_pipeline.clear()
        self.validation_pipeline.extend(methods)


class ArgumentParser(argparse.ArgumentParser):
    def __init__(self, *args, **kwargs):
        self.argument_validators = {}
        super().__init__(*args, **kwargs)

    def add_argument(self, *args, validator: Union[ArgumentValidator, Dict[Any, ArgumentValidator]] = None,
                    **kwargs) -> argparse.Action:
        arguments = super().add_argument(*args, **kwargs)
        if validator is not None:
            self.argument_validators.update({arguments.dest: validator})
        return arguments

    def parse_args(self, args=None, namespace=None) -> argparse.Namespace:
        args_all = super().parse_args(args, namespace)
        ArgumentValidator.context = vars(args_all)
        for arg, value in vars(args_all).items():
            if arg in self.argument_validators:
                validator = self.argument_validators[arg]
                type_of_value = type(value)
                try:
                    if isinstance(validator, dict):
                        if type_of_value is list:
                            type_of_value_to_validate = type(value[0])
                        else:
                            type_of_value_to_validate = type_of_value
                        if type_of_value_to_validate in validator:
                            validator[type_of_value_to_validate].validate(value)
                        else:
                            raise argparse.ArgumentTypeError(f"Validation failed for argument '{arg}': \
                                type {type_of_value_to_validate} not supported")
                    else:
                        validator.validate(value)
                except argparse.ArgumentTypeError as e:
                    raise argparse.ArgumentTypeError(f"Validation failed for argument '{arg}': {e}")
        return args_all

    def update_argument(self, old_name: str, new_name: str = None, **kwargs) -> None:
        old_name = old_name.lstrip('-')
        if new_name:
            kwargs.update({'dest': new_name.lstrip('-')})
        for action in self._actions:
            if action.dest == old_name:
                for key, value in kwargs.items():
                    setattr(action, key, value)


class StringArgumentValidator(ArgumentValidator):
    def __init__(self, min_length: int = 0, max_length: int = float('inf'), allow_none: bool = False):
        super().__init__(allow_none=allow_none)
        self.min_length = min_length
        self.max_length = max_length

    @staticmethod
    def validate_type(value: str) -> None:
        if not isinstance(value, str):
            raise argparse.ArgumentTypeError("Value must be a string")

    def validate_length(self, value: str) -> None:
        if not (self.min_length <= len(value) <= self.max_length):
            raise argparse.ArgumentTypeError(f"String length must be between {self.min_length} and {self.max_length}")

    def create_validation_pipeline(self):
        super()._create_validation_pipeline(self.validate_type, self.validate_length)